{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "hRTa3Ee15WsJ"
},
"source": [
"# Retrain a classification model for Edge TPU using post-training quantization (with TF2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TaX0smDP7xQY"
},
"source": [
"In this guide, we will utilize TensorFlow 2 to construct an image classification model, train it using a cat & dog dataset, and convert it into TensorFlow Lite through post-training quantization. The model is built upon a pre-trained MobileNet V2, where we initially retrain only the classification layers while reusing the pre-trained feature extractor layers. Subsequently, we fine-tune the model by updating weights in specific feature extractor layers, a faster approach compared to training the entire model from the ground up. Following the training, we employ post-training quantization to convert all parameters to int8 format, reducing the model size and enhancing inference speed. This int8 format is crucial for compatibility with the Edge TPU found in Coral devices.\n",
"\n",
"Refer to the [coral.ai documentation](https://coral.ai/docs/edgetpu/models-intro/) for additional information on creating Edge TPU-compatible models. It's important to note that this tutorial requires TensorFlow 2.3 or later for full quantization, and it specifically expects a Keras-built model. This conversion strategy is not compatible with models imported from a frozen graph. If using TensorFlow 1.x, you can refer to [the 1.x version of this tutorial](https://colab.research.google.com/github/google-coral/tutorials/blob/master/retrain_classification_ptq_tf1.ipynb)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GTCYQg_be8C0"
},
"source": [
"## Import the required libraries"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "02MxhCyFmpzn"
},
"source": [
"To perform quantization on both input and output tensors, you must utilize the `TFLiteConverter` APIs, which are accessible in TensorFlow versions equal to or greater than 2.3."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iBMcobPHdD8O"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
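{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check (added here for convenience, not part of the original flow), verify the TensorFlow version before continuing, since the full input/output quantization used later assumes TensorFlow 2.3 or later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: full integer quantization of input/output tensors needs TF >= 2.3\n",
"print(tf.__version__)\n",
"assert tuple(int(v) for v in tf.__version__.split('.')[:2]) >= (2, 3), 'TensorFlow 2.3 or later is required'"
]
},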
{
"cell_type": "markdown",
"metadata": {
"id": "v77rlkCKW0IJ"
},
"source": [
"## Prepare the training data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j4QOy2uA3P_p"
},
"source": [
"Initially, we will download and arrange the cat & dog dataset intended for retraining the model. It is essential to pay close attention to this section for future replication with your custom image dataset. Specifically, observe that the \"PATH\" directory is organized with directories bearing appropriate names for each class. The provided code randomizes and partitions the photos into training and validation sets, creating a labels file derived from the names of the photo folders."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xxL2mjVVGIrV",
"outputId": "7d0705ce-68fb-4365-d20e-a5238b8bfb44"
},
"outputs": [],
"source": [
"_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'\n",
"path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)\n",
"PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')\n",
"\n",
"train_dir = os.path.join(PATH, 'train')\n",
"validation_dir = os.path.join(PATH, 'validation')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "e1LBOHlOqL6X",
"outputId": "7a1f4a66-9b69-4a3c-8fcc-1fe8f536c63f"
},
"outputs": [],
"source": [
"PATH"
]
},
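{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick, optional check (added here for illustration), list the training directory to confirm that it contains one subdirectory per class; these folder names become the class labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Expect one folder per class (e.g., cats and dogs)\n",
"!ls {train_dir}"
]
},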
{
"cell_type": "markdown",
"metadata": {
"id": "z4gTv7ig2vMh"
},
"source": [
"Next, we use [`ImageDataGenerator`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator) to rescale the image data into float values (divide by 255 so the tensor values are between 0 and 1), and call `flow_from_directory()` to create two generators: one for the training dataset and one for the validation dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fwEKSBGz0mse",
"outputId": "820b407e-272f-41eb-c39d-68fc2f8638e5"
},
"outputs": [],
"source": [
"IMAGE_SIZE = 224\n",
"BATCH_SIZE = 64\n",
"\n",
"train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)\n",
"val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)\n",
"\n",
"train_generator = train_datagen.flow_from_directory(\n",
" train_dir,\n",
" target_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
" batch_size=BATCH_SIZE,\n",
" class_mode='categorical' # Change if you have a different problem (e.g., binary)\n",
")\n",
"\n",
"val_generator = val_datagen.flow_from_directory(\n",
" validation_dir,\n",
" target_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
" batch_size=BATCH_SIZE,\n",
" class_mode='categorical' # Change if you have a different problem (e.g., binary)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VePDZC5Bh2mO"
},
"source": [
"\n",
"On each iteration, these generators provide a batch of images by reading images from disk and processing them to the proper tensor size (224 x 224). The output is a tuple of (images, labels). For example, you can see the shapes here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tx1L7fxxWA_G",
"outputId": "b70acbde-d8fa-43e4-def2-2bd90812ee98"
},
"outputs": [],
"source": [
"image_batch, label_batch = next(val_generator)\n",
"image_batch.shape, label_batch.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZrFFcwUb3iK9"
},
"source": [
"Now save the class labels to a text file:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-QFZIhWs4dsq",
"outputId": "867279e9-e50c-4e97-a68f-5178189e1d9b"
},
"outputs": [],
"source": [
"print (train_generator.class_indices)\n",
"\n",
"labels = '\\n'.join(sorted(train_generator.class_indices.keys()))\n",
"\n",
"with open('catdogs_labels.txt', 'w') as f:\n",
" f.write(labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "duxD_UDSOmng",
"outputId": "ef2b2087-f10b-410e-ab66-cc8cb56f428c"
},
"outputs": [],
"source": [
"!cat catdogs_labels.txt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OkH-kazQecHB"
},
"source": [
"## Build the model\n",
"\n",
"Now, we will build a model designed for transfer learning specifically on the final fully-connected layer.\n",
"\n",
"Commencing with Keras' MobileNet V2 as the foundational model, which is initially pre-trained on the ImageNet dataset (encompassing training for the recognition of 1,000 classes). This offers us an excellent feature extractor for image classification, enabling us to subsequently train a novel classification layer tailored to our flowers dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "19IQ2gqneqmS",
"outputId": "2663ad3b-0355-4dce-8963-ff23a7056308"
},
"outputs": [],
"source": [
"IMG_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)\n",
"\n",
"# Create the base model from the pre-trained MobileNet V2\n",
"base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,\n",
" include_top=False,\n",
" weights='imagenet')\n",
"base_model.trainable = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eApvroIyn1K0"
},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" base_model,\n",
" tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation='relu'),\n",
" tf.keras.layers.Dropout(0.2),\n",
" tf.keras.layers.GlobalAveragePooling2D(),\n",
" tf.keras.layers.Dense(units=2, activation='softmax')\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "g0ylJXE_kRLi"
},
"source": [
"### Configure the model\n",
"\n",
"Although this method is called `compile()`, it's basically a configuration step that's required before we can start training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RpR8HdyMhukJ"
},
"outputs": [],
"source": [
"model.compile(optimizer='adam',\n",
" loss='categorical_crossentropy',\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "I8ARiyMFsgbH",
"outputId": "e9a873c7-5a14-4b33-8035-74c05db349bf"
},
"outputs": [],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "krvBumovycVA",
"outputId": "3b485253-0a7a-4c64-ee0c-358cdf1e59f0"
},
"outputs": [],
"source": [
"print('Number of trainable weights = {}'.format(len(model.trainable_weights)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RxvgOYTDSWTx"
},
"source": [
"## Train the model\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kBRNaOCCoA-P"
},
"source": [
"Now we can train the model using data provided by the `train_generator` and `val_generator` that we created at the beginning.\n",
"\n",
"This should take less than 10 minutes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JsaRFlZ9B6WK",
"outputId": "65bc2ad9-dd99-4cb0-e8a2-4c1268c13aed"
},
"outputs": [],
"source": [
"history = model.fit(train_generator,\n",
" steps_per_epoch=len(train_generator),\n",
" epochs=10,\n",
" validation_data=val_generator,\n",
" validation_steps=len(val_generator))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hd94CKImf8vi"
},
"source": [
"### Plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 718
},
"id": "53OTCh3jnbwV",
"outputId": "1a094684-31ad-48a6-d6df-535ee1fe03f7"
},
"outputs": [],
"source": [
"acc = history.history['accuracy']\n",
"val_acc = history.history['val_accuracy']\n",
"\n",
"loss = history.history['loss']\n",
"val_loss = history.history['val_loss']\n",
"\n",
"plt.figure(figsize=(8, 8))\n",
"plt.subplot(2, 1, 1)\n",
"plt.plot(acc, label='Training Accuracy')\n",
"plt.plot(val_acc, label='Validation Accuracy')\n",
"plt.legend(loc='lower right')\n",
"plt.ylabel('Accuracy')\n",
"plt.ylim([min(plt.ylim()),1])\n",
"plt.title('Training and Validation Accuracy')\n",
"\n",
"plt.subplot(2, 1, 2)\n",
"plt.plot(loss, label='Training Loss')\n",
"plt.plot(val_loss, label='Validation Loss')\n",
"plt.legend(loc='upper right')\n",
"plt.ylabel('Cross Entropy')\n",
"plt.ylim([0,1.0])\n",
"plt.title('Training and Validation Loss')\n",
"plt.xlabel('epoch')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CqwV-CRdS6Nv"
},
"source": [
"## Fine tune the base model\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dBTEEnxv9X6J"
},
"source": [
"So far, the training focused on the classification layers, and the weights of the pre-trained network remained *unchanged*.\n",
"\n",
"To improve accuracy, one approach involves training or \"fine-tuning\" additional layers from the pre-trained model. In other words, we'll unfreeze specific layers from the base model and modify their weights, originally trained on 1,000 ImageNet classes, to better align with features present in our cat & dog dataset."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CPXnzUK0QonF"
},
"source": [
"### Un-freeze more layers\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rfxv_ifotQak"
},
"source": [
"So instead of freezing the entire base model, we'll freeze individual layers.\n",
"\n",
"First, let's see how many layers are in the base model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4nzcagVitLQm",
"outputId": "a3f56ff7-bae8-4829-a6bb-489d218e66a7"
},
"outputs": [],
"source": [
"print(\"Number of layers in the base model: \", len(base_model.layers))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dGcXdaQqASlC"
},
"source": [
"Let's try freezing just the bottom 100 layers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-4HgVAacRs5v"
},
"outputs": [],
"source": [
"base_model.trainable = True\n",
"fine_tune_at = 100\n",
"\n",
"# Freeze all the layers before the `fine_tune_at` layer\n",
"for layer in base_model.layers[:fine_tune_at]:\n",
" layer.trainable = False"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4Uk1dgsxT0IS"
},
"source": [
"### Reconfigure the model\n",
"\n",
"Now configure the model again, but this time with a lower learning rate (the default is 0.001)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NtUnaz0WUDva"
},
"outputs": [],
"source": [
"model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),\n",
" loss='categorical_crossentropy',\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WwBWy7J2kZvA",
"outputId": "85867242-fd47-4e61-a173-b3c06d719039"
},
"outputs": [],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bNXelbMQtonr",
"outputId": "693ffaf8-e0cd-49ae-b12e-d39e9c8c82e0"
},
"outputs": [],
"source": [
"print('Number of trainable weights = {}'.format(len(model.trainable_weights)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4G5O4jd6TuAG"
},
"source": [
"### Continue training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bppmJTmDpXtK"
},
"source": [
"Now let's fine-tune all trainable layers. This starts with the weights we already trained in the classification layers, so we don't need as many epochs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PiXbLb1O8IDy",
"outputId": "2f0e53e9-27fd-4a4d-a380-ff0c86e80840"
},
"outputs": [],
"source": [
"history_fine = model.fit(train_generator,\n",
" steps_per_epoch=len(train_generator),\n",
" epochs=5,\n",
" validation_data=val_generator,\n",
" validation_steps=len(val_generator))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xqIjZvhBBJNn"
},
"source": [
"### New Plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 718
},
"id": "chW103JUItdk",
"outputId": "03b3f91d-08c8-4824-d9a5-6d3a1083866a"
},
"outputs": [],
"source": [
"acc = history_fine.history['accuracy']\n",
"val_acc = history_fine.history['val_accuracy']\n",
"\n",
"loss = history_fine.history['loss']\n",
"val_loss = history_fine.history['val_loss']\n",
"\n",
"plt.figure(figsize=(8, 8))\n",
"plt.subplot(2, 1, 1)\n",
"plt.plot(acc, label='Training Accuracy')\n",
"plt.plot(val_acc, label='Validation Accuracy')\n",
"plt.legend(loc='lower right')\n",
"plt.ylabel('Accuracy')\n",
"plt.ylim([min(plt.ylim()),1])\n",
"plt.title('Training and Validation Accuracy')\n",
"\n",
"plt.subplot(2, 1, 2)\n",
"plt.plot(loss, label='Training Loss')\n",
"plt.plot(val_loss, label='Validation Loss')\n",
"plt.legend(loc='upper right')\n",
"plt.ylabel('Cross Entropy')\n",
"plt.ylim([0,1.0])\n",
"plt.title('Training and Validation Loss')\n",
"plt.xlabel('epoch')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kRDabW_u1wnv"
},
"source": [
"## Convert to TFLite"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hNvMl6CM6lG4"
},
"source": [
"Ordinarily, creating a TensorFlow Lite model is just a few lines of code with [`TFLiteConverter`](https://www.tensorflow.org/api_docs/python/tf/lite/TFLiteConverter). For example, this creates a basic (un-quantized) TensorFlow Lite model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "srOYhMYfx9XH",
"outputId": "cedd344d-3dd2-4904-e829-341239bd44b5"
},
"outputs": [],
"source": [
"converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
"tflite_model = converter.convert()\n",
"\n",
"with open('mobilenet_v2_1.0_224.tflite', 'wb') as f:\n",
" f.write(tflite_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P0_StAtZwJ5p"
},
"source": [
"However, this `.tflite` file still uses floating-point values for the parameter data, and we need to fully quantize the model to int8 format.\n",
"\n",
"To fully quantize the model, we need to perform [post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) with a representative dataset, which requires a few more arguments for the `TFLiteConverter`, and a function that builds a dataset that's representative of the training dataset.\n",
"\n",
"So let's convert the model again with post-training quantization:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "w9ydAmHGHUZl",
"outputId": "591f0479-a0df-4973-f654-f186f523f654"
},
"outputs": [],
"source": [
"# A generator that provides a representative dataset\n",
"def representative_data_gen():\n",
" dataset_list = tf.data.Dataset.list_files(PATH + '/train/*/*')\n",
" for i in range(100):\n",
" image = next(iter(dataset_list))\n",
" image = tf.io.read_file(image)\n",
" image = tf.io.decode_jpeg(image, channels=3)\n",
" image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])\n",
" image = tf.cast(image / 255., tf.float32)\n",
" image = tf.expand_dims(image, 0)\n",
" yield [image]\n",
"\n",
"converter = tf.lite.TFLiteConverter.from_keras_model(model)\n",
"# This enables quantization\n",
"converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
"# This sets the representative dataset for quantization\n",
"converter.representative_dataset = representative_data_gen\n",
"# This ensures that if any ops can't be quantized, the converter throws an error\n",
"converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]\n",
"# For full integer quantization, though supported types defaults to int8 only, we explicitly declare it for clarity.\n",
"converter.target_spec.supported_types = [tf.int8]\n",
"# These set the input and output tensors to uint8 (added in r2.3)\n",
"converter.inference_input_type = tf.uint8\n",
"converter.inference_output_type = tf.uint8\n",
"tflite_model = converter.convert()\n",
"\n",
"with open('mobilenet_v2_1.0_224_quant.tflite', 'wb') as f:\n",
" f.write(tflite_model)"
]
},
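{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a small sketch added here, not part of the original tutorial), load the quantized model with the TFLite interpreter and confirm that the input and output tensors are now uint8:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Both should report uint8 if full integer quantization succeeded\n",
"interpreter = tf.lite.Interpreter(model_path='mobilenet_v2_1.0_224_quant.tflite')\n",
"interpreter.allocate_tensors()\n",
"print('Input type: ', interpreter.get_input_details()[0]['dtype'])\n",
"print('Output type:', interpreter.get_output_details()[0]['dtype'])"
]
},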
{
"cell_type": "markdown",
"metadata": {
"id": "RMLYBDe_e849"
},
"source": [
"### Compare the accuracy\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SFgbRx_Twd-P"
},
"source": [
"So now we have a fully quantized TensorFlow Lite model. To be sure the conversion went well, let's evaluate both the raw model and the TensorFlow Lite model.\n",
"\n",
"First check the accuracy of the raw model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RkQ2IlAWfC5O",
"outputId": "6ea57df1-4e58-4749-cb92-5a7a5c9302f3"
},
"outputs": [],
"source": [
"batch_images, batch_labels = next(val_generator)\n",
"\n",
"logits = model(batch_images)\n",
"prediction = np.argmax(logits, axis=1)\n",
"truth = np.argmax(batch_labels, axis=1)\n",
"\n",
"keras_accuracy = tf.keras.metrics.Accuracy()\n",
"keras_accuracy(prediction, truth)\n",
"\n",
"print(\"Raw model accuracy: {:.3%}\".format(keras_accuracy.result()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hjx3dgZNwmKN"
},
"source": [
"Now let's check the accuracy of the `.tflite` file, using the same dataset.\n",
"\n",
"However, there's no convenient API to evaluate the accuracy of a TensorFlow Lite model, so this code runs several inferences and compares the predictions against ground truth:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iBs0O7q_wlCN",
"outputId": "cca2bbdc-310b-4448-f3b1-59f0aa69b079"
},
"outputs": [],
"source": [
"def set_input_tensor(interpreter, input):\n",
" input_details = interpreter.get_input_details()[0]\n",
" tensor_index = input_details['index']\n",
" input_tensor = interpreter.tensor(tensor_index)()[0]\n",
" # Inputs for the TFLite model must be uint8, so we quantize our input data.\n",
" # NOTE: This step is necessary only because we're receiving input data from\n",
" # ImageDataGenerator, which rescaled all image data to float [0,1]. When using\n",
" # bitmap inputs, they're already uint8 [0,255] so this can be replaced with:\n",
" # input_tensor[:, :] = input\n",
" scale, zero_point = input_details['quantization']\n",
" input_tensor[:, :] = np.uint8(input / scale + zero_point)\n",
"\n",
"def classify_image(interpreter, input):\n",
" set_input_tensor(interpreter, input)\n",
" interpreter.invoke()\n",
" output_details = interpreter.get_output_details()[0]\n",
" output = interpreter.get_tensor(output_details['index'])\n",
" # Outputs from the TFLite model are uint8, so we dequantize the results:\n",
" scale, zero_point = output_details['quantization']\n",
" output = scale * (output - zero_point)\n",
" top_1 = np.argmax(output)\n",
" return top_1\n",
"\n",
"interpreter = tf.lite.Interpreter('mobilenet_v2_1.0_224_quant.tflite')\n",
"interpreter.allocate_tensors()\n",
"\n",
"# Collect all inference predictions in a list\n",
"batch_prediction = []\n",
"batch_truth = np.argmax(batch_labels, axis=1)\n",
"\n",
"for i in range(len(batch_images)):\n",
" prediction = classify_image(interpreter, batch_images[i])\n",
" batch_prediction.append(prediction)\n",
"\n",
"# Compare all predictions to the ground truth\n",
"tflite_accuracy = tf.keras.metrics.Accuracy()\n",
"tflite_accuracy(batch_prediction, batch_truth)\n",
"print(\"Quant TF Lite accuracy: {:.3%}\".format(tflite_accuracy.result()))\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WfM4kAPiPg9q"
},
"source": [
"You might see some, but hopefully not very much accuracy drop between the raw model and the TensorFlow Lite model. But again, these results are not suitable for production deployment."
]
},
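{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also compare the size of the two `.tflite` files on disk (a rough, illustrative check added here): the fully quantized model should be roughly a quarter of the float model's size, since the weights are stored as int8 instead of float32."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare on-disk sizes of the float vs. fully quantized TFLite models\n",
"for path in ['mobilenet_v2_1.0_224.tflite', 'mobilenet_v2_1.0_224_quant.tflite']:\n",
"  print('{}: {:.2f} MB'.format(path, os.path.getsize(path) / 1e6))"
]
},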
{
"cell_type": "markdown",
"metadata": {
"id": "ZmiHICezwXZq"
},
"source": [
"## Compile for the Edge TPU\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DhOzAdzF3Dyk"
},
"source": [
"Finally, we're ready to compile the model for the Edge TPU.\n",
"\n",
"First download the [Edge TPU Compiler](https://coral.ai/docs/edgetpu/compiler/):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "p6ZpWgrk21Ad",
"outputId": "4ea89f0f-6e9e-4963-a243-8dd3b05d3896"
},
"outputs": [],
"source": [
"! curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -\n",
"\n",
"! echo \"deb https://packages.cloud.google.com/apt coral-edgetpu-stable main\" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list\n",
"\n",
"! sudo apt-get update\n",
"\n",
"! sudo apt-get install edgetpu-compiler"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mtPcYiER3Ymp"
},
"source": [
"Then compile the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "joxrIB0I3cdi",
"outputId": "fad773e7-9d68-4061-cd13-05cb7dddebed"
},
"outputs": [],
"source": [
"! edgetpu_compiler mobilenet_v2_1.0_224_quant.tflite"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7R8JMQc1MMm5"
},
"source": [
"That's it.\n",
"\n",
"The compiled model uses the same filename but with \"_edgetpu\" appended at the end."
]
},
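{
"cell_type": "markdown",
"metadata": {},
"source": [
"To confirm the output (an optional check added here), list the files produced by the compiler:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Should show the compiled .tflite model (and the compiler's log, if one was written)\n",
"!ls *edgetpu*"
]
},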
{
"cell_type": "markdown",
"metadata": {
"id": "Oi9-Voc8A7VK"
},
"source": [
"## Download the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XiugMm-jBbWl"
},
"source": [
"You can download the converted model and labels file from Colab like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x47uW_lI1DoV"
},
"outputs": [],
"source": [
"from google.colab import files\n",
"\n",
"files.download('mobilenet_v2_1.0_224_quant_edgetpu.tflite')\n",
"files.download('catdogs_labels.txt')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_qOCP3mXXvsm"
},
"source": [
"If you get a \"Failed to fetch\" error here, it's probably because the files weren't done saving. So just wait a moment and try again.\n",
"\n",
"Also look out for a browser popup that might need approval to download the files."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_TZTwG7nhm0C"
},
"source": [
"## Run the model on the Edge TPU\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RwywT4ZpQjLf"
},
"source": [
"You can now run the model on your Coral device with acceleration on the Edge TPU.\n",
"\n",
"To get started, try using your `.tflite` model with [this code for image classification with the TensorFlow Lite API](https://github.com/google-coral/tflite/tree/master/python/examples/classification).\n",
"\n",
"Just follow the instructions on that page to set up your device, copy the `mobilenet_v2_1.0_224_quant_edgetpu.tflite` and `flower_labels.txt` files to your Coral Dev Board or device with a Coral Accelerator, and pass it a flower photo like this:\n",
"\n",
"```\n",
"python3 classify_image.py \\\n",
" --model mobilenet_v2_1.0_224_quant_edgetpu.tflite \\\n",
" --labels flower_labels.txt \\\n",
" --input flower.jpg\n",
"```\n",
"\n",
"This notebook got reference from [coral.ai/examples](https://coral.ai/examples/#code-examples/). Check more example there"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Z22ABdF88xWu"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}